
[Core] Refactor _prepare_model_input_tensors - take 2 #6164

Merged · 4 commits · Jul 17, 2024

Conversation

@comaniac (Collaborator) commented on Jul 6, 2024:

This PR refactors _prepare_model_input_tensors. Specifically, we introduce ModelRunnerInputBuilder, mainly for logic isolation and modularization: it manages all processed input data, including token IDs, positions, sequence lengths, etc., in one place, and isolates the following logic (a rough sketch follows the list):
  1. Inserting a new sequence group into the input data, taking prefix caching, chunked prefill, sliding windows, etc. into account.
  2. Preparing attention inputs.
  3. Preparing LoRA and multi-modal inputs.
  4. Creating on-device tensors for model inputs.
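
For illustration, here is a minimal sketch of the builder's shape, assuming simplified fields and signatures (only the class and method names come from this PR; the fields and bodies are illustrative, not the actual vLLM code):

class ModelRunnerInputBuilder:
    """Collects per-sequence-group data and builds the model input."""

    def __init__(self, runner):
        self.runner = runner
        # All processed input data lives in one place.
        self.input_tokens = []      # token IDs
        self.input_positions = []   # positions
        self.seq_lens = []          # sequence lengths

    def add_seq_group(self, seq_group_metadata):
        # Item 1: insert one sequence group, handling prefix caching,
        # chunked prefill, sliding windows, etc.
        ...

    def build(self):
        # Items 2-4: prepare attention, LoRA, and multi-modal inputs,
        # then create the on-device tensors for the model.
        ...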

Note that the purpose of this PR is to enable follow-up refactoring and optimizations, so we don't expect an obvious performance improvement at this moment.

With this isolation in place, we can pursue the following follow-up optimizations:

  1. Refactor AttentionMetadata to only include on-device tensors, and move all related logic from ModelRunnerInputBuilder.
  2. Remove the loop for seq_id in seq_ids in ModelRunnerInputBuilder._add_seq_group() by leveraging tensor processing.
  3. Parallelize the loop for seq_group_metadata in seq_group_metadata_list.
  4. and more.

@comaniac comaniac marked this pull request as draft July 6, 2024 01:05
@comaniac comaniac force-pushed the prepare_input_v2 branch 2 times, most recently from c1aa929 to 4b61a55 Compare July 7, 2024 21:22
multi_modal_kwargs = MultiModalInputs.batch(
    self.multi_modal_inputs_list, device=runner.device)

return self.model_input_cls(
Member:
Can we have a to(device) API? Before this call, all the prepared input lives on the CPU.

comaniac (Collaborator, Author):

So you suggest something like the following?

        builder = ModelInputForGPUBuilder(...)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)
        model_input = builder.build(self) # All tensors are on CPU.
        return model_input.to(device) # Move to GPU.

This would introduce some overhead, as we create the tensors on the CPU and then move them to the GPU, instead of creating them directly on the GPU.

Member:

Yes, this would simplify the control-plane communication.

Actually, even if you create them directly on the GPU, PyTorch still needs to prepare them on the CPU and then move them to the GPU together. I don't think there would be significant overhead here.

Member:

The following code snippets are essentially the same:

import torch

# Create the tensor directly on the GPU:
data = list(range(1000))
gpu_tensor = torch.tensor(data, dtype=torch.int, device="cuda")

# Create it on the CPU first, then move it over:
data = list(range(1000))
cpu_tensor = torch.tensor(data, dtype=torch.int, device="cpu")
gpu_tensor = cpu_tensor.to("cuda")

In the first snippet, PyTorch still needs to create a CPU array to hold the Python data structure and then move it to the GPU.

If we do the CPU-to-GPU movement before we broadcast to the rest of the workers, they will sit idle waiting for the broadcast.

timeline:

driver worker: |  prepare input in cpu | move to GPU | broadcast |
rest workers:  |    wait for broadcast                         |

If we do the CPU-to-GPU movement after we broadcast, the driver worker sends CPU data (which can be fast because we have a shared-memory transport), and each worker then moves the data to its GPU independently.

timeline:

driver worker: |  prepare input in cpu | broadcast | move to GPU |
rest workers:  |    wait for broadcast           | move to GPU |

Because these data are usually quite small, I suppose the second approach would be faster: broadcasting small CPU data is cheap, while broadcasting small GPU tensors with NCCL performs poorly.
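
For illustration, a minimal sketch of the proposed pattern, assuming a simple dataclass with tensor fields (the class and field names here are placeholders, not the actual vLLM types):

import dataclasses
import torch

@dataclasses.dataclass
class ModelInputSketch:
    # Placeholder fields; the real model input holds many more.
    input_tokens: torch.Tensor
    input_positions: torch.Tensor

    def to(self, device):
        # Move every tensor field to the target device.
        return dataclasses.replace(
            self,
            input_tokens=self.input_tokens.to(device, non_blocking=True),
            input_positions=self.input_positions.to(device, non_blocking=True),
        )

# Driver: build on CPU, broadcast the small CPU tensors, and let every
# worker (driver included) move its copy to the GPU locally.
model_input = ModelInputSketch(
    input_tokens=torch.zeros(8, dtype=torch.long),
    input_positions=torch.arange(8),
)
model_input = model_input.to("cuda" if torch.cuda.is_available() else "cpu")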

comaniac (Collaborator, Author):

Hmm... makes sense, since we will have a Python list anyway. I'll think about it and add this API if there are no other concerns. Thanks.

comaniac (Collaborator, Author):

I've basically done the first round of refactoring and feel this is doable, but I'd prefer to defer it until #6052 is merged, because I plan to add .to() to the following data classes:

  1. ModelInputForGPU, which will implicitly call attn_metadata.to(...) (the next point).
  2. AttentionMetadata, which is where the conflict with #6052 arises.

@comaniac comaniac force-pushed the prepare_input_v2 branch 2 times, most recently from b37aeab to 546504b Compare July 9, 2024 22:23
@comaniac (Collaborator, Author) commented on Jul 10, 2024:

@rkooo567 @zhuohan123 @simon-mo @WoosukKwon @youkaichao @LiuXiaoxuanPKU I've done the first round of refactoring:

  1. The attention-unrelated logic (tokens, sequence lengths, LoRA, multi-modal inputs, etc.) remains in prepare_input.
  2. Prefill and decode logic are kept together.
  3. Attention-specific logic, such as block tables and slot mapping, is moved to the attention metadata builder (sketched after this list).
  4. The FlashAttention and FlashInfer metadata builders are self-contained.
  5. The xFormers / ROCmFlashAttention / BlockSparseAttention metadata builders share the same utility functions.
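
A rough sketch of the resulting split, assuming simplified signatures (the class names loosely follow the points above; the arguments and bodies are assumptions, not the actual vLLM code):

class AttentionMetadataBuilder:
    """Base class: backend-specific block-table / slot-mapping logic."""

    def __init__(self, input_builder):
        # Attention-specific state collected per sequence group.
        self.input_builder = input_builder
        self.slot_mapping = []
        self.block_tables = []

    def add_seq_group(self, seq_group_metadata, seq_len, context_len):
        # Fill in slot mapping and block tables for this sequence group.
        raise NotImplementedError

    def build(self, seq_lens, query_lens):
        # Create the backend's on-device AttentionMetadata.
        raise NotImplementedError


class FlashAttentionMetadataBuilder(AttentionMetadataBuilder):
    # Self-contained (point 4); the xFormers / ROCm / BlockSparse builders
    # instead share common utility functions (point 5).
    def add_seq_group(self, seq_group_metadata, seq_len, context_len):
        ...

    def build(self, seq_lens, query_lens):
        ...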

This PR is ready for review. I'll wait for CI to be green first and then rebase to resolve the conflict.

A remaining concern that could potentially be addressed in this PR: the argument list of attn_metadata_builder.add_seq_group() is ugly. One reason is that we have to compute the sliding-window sequence length outside of the attention metadata (because sequence length is common), yet we also need the original sequence length to compute the block table and slot mapping inside the attention metadata.

Follow-up PRs: move more attention-related logic, such as dummy inputs for CUDA graph capturing and the pre-/post-processing logic in forward.

@comaniac comaniac marked this pull request as ready for review July 10, 2024 22:22
@comaniac comaniac changed the title from "[Draft][Core] Refactor _prepare_model_input_tensors - take 2" to "[Core] Refactor _prepare_model_input_tensors - take 2" on Jul 15, 2024
@rkooo567 (Collaborator) left a comment:

Looks pretty good (especially given that tests are passing). Just a few comments:

  • Comments, docstrings, and typing for readability.
  • I think runner should be passed as a weakref everywhere (see the sketch below).
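
For illustration, a minimal sketch of what passing the runner as a weakref could look like (the attribute name and constructor signature are assumptions):

import weakref

class ModelInputForGPUBuilder:
    def __init__(self, runner):
        # Hold a weak proxy so the builder does not keep the runner alive;
        # the runner typically owns the builder, so a strong reference
        # would create a reference cycle.
        self.runner = weakref.proxy(runner)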

Resolved review threads on: vllm/attention/backends/abstract.py, vllm/worker/model_runner.py, vllm/attention/backends/flash_attn.py, vllm/attention/backends/utils.py

# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(
Collaborator:

Why is compute_slot_mapping_start_idx split from compute_slot_mapping? It seems like they are always used together?

comaniac (Collaborator, Author):

Yeah, the only reason is that if we combine these two, the argument list would be even uglier:

compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                     seq_len, context_len, self.block_size,
                     seq_group_metadata.block_tables, is_prompt,
                     query_len, self.sliding_window,
                     self.use_v2_block_manager)

I don't have a strong preference though, so if you prefer to combine them I could do that. Please let me know.

Collaborator:

I'd prefer to combine them if you don't mind! Agreed there are too many args, but I think it's better than having two calls that depend on each other (IMO that is harder to use). But it's a nit!

comaniac (Collaborator, Author):

Sure. I'll merge this one and do it in the next PR.

Resolved review threads on: vllm/worker/model_runner.py, vllm/attention/backends/flashinfer.py
@rkooo567 (Collaborator):
feel free to merge it after addressing comments!

@comaniac comaniac added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 16, 2024
@comaniac comaniac merged commit 2fa4623 into vllm-project:main Jul 17, 2024
87 of 88 checks passed
@comaniac comaniac deleted the prepare_input_v2 branch July 17, 2024 16:37
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Jul 19, 2024
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
gnpinkert pushed a commit to gnpinkert/vllm that referenced this pull request Jul 26, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
3 participants